package cn.js189.gateway.filter;

import cn.hutool.core.text.CharSequenceUtil;
import cn.js189.common.util.ServletUtils;
import cn.js189.gateway.config.SqlWhiteProperties;
import cn.js189.gateway.utils.ReqUtils;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import org.springframework.cloud.gateway.filter.GatewayFilterChain;
import org.springframework.cloud.gateway.filter.GlobalFilter;
import org.springframework.core.Ordered;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DataBufferUtils;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.http.server.reactive.ServerHttpRequestDecorator;
import org.springframework.stereotype.Component;
import org.springframework.util.StringUtils;
import org.springframework.web.reactive.function.server.HandlerStrategies;
import org.springframework.web.reactive.function.server.ServerRequest;
import org.springframework.web.server.ServerWebExchange;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

import javax.annotation.Resource;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.Iterator;
import java.util.Map;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/**
 * SQL注入过滤器
 */
@Slf4j
@Component
public class SqlInjectionFilter implements GlobalFilter, Ordered  {
	
	@Resource
	private SqlWhiteProperties sqlWhiteProperties;
	
	/**
	 * SQL注入正则判断
	 */
	private static final String BAD_STR_REQ = "\b(and|exec|insert|select|drop|grant|alter|delete|update|count|chr|mid|master|truncate|char|declare|or)\b|([*;+'%])";
	
	/**
	 * 整体都忽略大小写
	 */
	private static final Pattern sqlPattern = Pattern.compile(BAD_STR_REQ, Pattern.CASE_INSENSITIVE);
	
	@Override
	public Mono<Void> filter(ServerWebExchange exchange, GatewayFilterChain chain) {
		// 获取请求路径、请求方法、请求参数
		ServerHttpRequest request = exchange.getRequest();
		String requestPath = request.getPath().toString();
		HttpMethod method = request.getMethod();
		MediaType contentType = request.getHeaders().getContentType();
		
		// 白名单 不过滤
		String url = request.getURI().getPath();
		if (cn.js189.common.util.StringUtils.matches(url, sqlWhiteProperties.getWhites()))
		{
			return chain.filter(exchange);
		}
		
		// 判断Param是否存在SQL注入
		AtomicBoolean isSqlInjection = new AtomicBoolean(false);
		request.getQueryParams().forEach((key, values) -> {
			for (String value : values) {
				if (StringUtils.hasText(value) && checkSqlInjection(value)) {
					isSqlInjection.set(true);
					return;
				}
			}
		});
		if (isSqlInjection.get()) {
			return errorResponse(exchange,"请求参数中包含不允许sql的关键词，请求拒绝");
		}
		// contentType不为空，一般说明body中有参数，判断body中是否存在SQL注入，如果存在则直接返回错误信息
		if (contentType != null) {
			return DataBufferUtils.join(exchange.getRequest().getBody()).flatMap(dataBuffer -> {
				// 取出body中的参数
				byte[] bytes = new byte[dataBuffer.readableByteCount()];
				dataBuffer.read(bytes);
				DataBufferUtils.release(dataBuffer);
				Flux<DataBuffer> cachedFlux = Flux.defer(() -> {
					DataBuffer buffer = exchange.getResponse().bufferFactory().wrap(bytes);
					DataBufferUtils.retain(buffer);
					return Mono.just(buffer);
				});
				
				String bodyString = new String(bytes, StandardCharsets.UTF_8);
				this.sqlFilter(contentType, isSqlInjection, bodyString);
				//  如果存在sql注入,直接拦截请求
				if (isSqlInjection.get()) {
					log.error("SqlInjectionFilter {} - [{}] 参数：{}, 包含不允许sql的关键词，请求拒绝", method, requestPath, bodyString);
					return errorResponse(exchange,"请求参数中包含不允许sql的关键词，请求拒绝");
				}
				// 重新包装ServerHttpRequest
				ServerHttpRequest mutatedRequest = new ServerHttpRequestDecorator(exchange.getRequest()) {
					@NonNull
					@Override
					public Flux<DataBuffer> getBody() {
						return cachedFlux;
					}
				};
				// 用新的ServerHttpRequest改变交换对象
				ServerWebExchange mutatedExchange = exchange.mutate().request(mutatedRequest).build();
				// 使用默认的messageReaders读取正文字符串
				return ServerRequest.create(mutatedExchange, HandlerStrategies.withDefaults().messageReaders()).bodyToMono(String.class).doOnNext(objectValue ->
						log.info("SqlInjectionFilter end requestPath：{}, method：{}, contentType：{}", requestPath, method, contentType))
						.then(chain.filter(mutatedExchange));
			});
		}
		log.info("SqlInjectionFilter end requestPath：{}, method：{}, contentType：{}", requestPath, method, null);
		return chain.filter(exchange);
	}
	
	/**
	 * sql 过滤
	 * @param contentType 请求上下文类型
	 * @param isSqlInjection 日志过滤
	 * @param bodyString 参数
	 */
	private void sqlFilter(MediaType contentType,AtomicBoolean isSqlInjection,String bodyString){
		if (MediaType.APPLICATION_JSON.isCompatibleWith(contentType)) {
			if (StringUtils.hasText(bodyString) && checkJsonBody(bodyString)) {
				isSqlInjection.set(true);
			}
		} else if (MediaType.APPLICATION_FORM_URLENCODED.isCompatibleWith(contentType)) {
			if (StringUtils.hasText(bodyString) && checkFormUrlencoded(bodyString)) {
				isSqlInjection.set(true);
			}
		} else if (MediaType.MULTIPART_FORM_DATA.isCompatibleWith(contentType) && (StringUtils.hasText(bodyString) && checkFormData(bodyString))) {
				isSqlInjection.set(true);
		}
	}
	
	/**
	 * 检测multipart/form-data是否包含SQL注入关键字
	 *
	 * @param bodyString 请求体
	 * @return 是否包含SQL注入关键字
	 */
	public boolean checkFormData(String bodyString) {
		bodyString = bodyString.replace("\r", "");
		String[] parts = bodyString.split("\n");
		for (int i = 0; i < parts.length; i++) {
			String part = parts[i];
			
			if (part.contains("Content-Disposition: form-data;")) {
				// 文件类型不检测
				if (part.contains("file")) {
					log.info("SqlInjectionFilter file skip");
					return false;
				}
				String value = parts[i + 2];
				if (CharSequenceUtil.isBlank(value)) {
					continue;
				}
				if (checkSqlInjection(value)) {
					return true;
				}
			}
		}
		return false;
	}
	
	/**
	 * 检测application/x-www-form-urlencoded是否包含SQL注入关键字
	 *
	 * @param bodyString 请求体
	 * @return 是否包含SQL注入关键字
	 */
	private boolean checkFormUrlencoded(String bodyString) {
		String[] params = bodyString.split("&");
		for (String param : params) {
			String[] keyValue = param.split("=");
			if (keyValue.length == 2 &&
					CharSequenceUtil.isNotBlank(keyValue[1]) &&
					checkSqlInjection(keyValue[1])) {
				// 判断是否为空
				return true;
			}
		}
		return false;
	}
	
	/**
	 * 检测application/json是否包含SQL注入关键字
	 *
	 * @param body 请求体
	 * @return 是否包含SQL注入关键字
	 */
	private boolean checkJsonBody(String body) {
		try {
			ObjectMapper obj = new ObjectMapper();
			JsonNode rootNode = obj.readTree(body);
			return checkJsonNode(rootNode);
		} catch (IOException e) {
			log.error("SqlInjectionFilter Error while parsing JSON body", e);
			return false;
		}
	}
	
	
	private boolean checkJsonNode(JsonNode node) {
		if (node.isValueNode()) {
			return checkSqlInjection(node.asText());
		} else if (node.isObject()) {
			Iterator<Map.Entry<String, JsonNode>> fieldsIterator = node.fields();
			while (fieldsIterator.hasNext()) {
				Map.Entry<String, JsonNode> field = fieldsIterator.next();
				// 判断是否为空
				if (!field.getValue().isNull() && checkJsonNode(field.getValue())) {
					return true;
				}
			}
		} else if (node.isArray()) {
			for (JsonNode childNode : node) {
				if (checkJsonNode(childNode)) {
					return true;
				}
			}
		}
		return false;
	}
	
	/**
	 * 判断输入的字符串是否包含SQL注入
	 *
	 * @param str 输入的字符串
	 * @return 如果输入的字符串包含SQL注入，返回 true，否则返回 false。
	 */
	private boolean checkSqlInjection(String str) {
		str = str.toLowerCase();
		Matcher matcher = sqlPattern.matcher(str);
		if (matcher.find()) {
			log.error("SqlInjectionFilter 参数[{}]中包含不允许sql的关键词", str);
			return true;
		}
		return false;
	}
	
	/**
	 * 返回错误响应
	 * @param exchange 响应
	 * @return Mono<Void>
	 */
	private Mono<Void> errorResponse(ServerWebExchange exchange, String msg) {
		log.error("[鉴权异常处理]请求路径:{}", exchange.getRequest().getPath());
		return ServletUtils.webFluxResponseWriter(exchange.getResponse(), msg, cn.js189.common.constants.HttpStatus.SQL_FILTER_ERROR);
	}
	
	@Override
	public int getOrder() {
		// 值越小，越先执行
		return -200;
	}
	
}
